# -*- coding: utf-8 -*-
# @Time    : 2022/2/5 9:45 上午
# @Author  : DIRICHLET
# @Email   : 511827625@qq.com
# @File    : cvaeExample1.py
# @Software: PyCharm
# @ Better Late Than Never!
# -*- coding: utf-8 -*-
"""
Created on Thu May 31 15:34:08 2018

@author: zy
"""

'''
条件变分自编码
'''

#过程式编程

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Fetch MNIST (downloaded on first use, cached under the 'MNIST-data'
# directory) with labels one-hot encoded.
mnist = input_data.read_data_sets('MNIST-data', one_hot=True)

print(type(mnist))  # <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>

# Each image is a 28x28 picture flattened into 784 values.
print('Training data shape:', mnist.train.images.shape)  # Training data shape: (55000, 784)
print('Test data shape:', mnist.test.images.shape)  # Test data shape: (10000, 784)
print('Validation data shape:', mnist.validation.images.shape)  # Validation data shape: (5000, 784)
print('Training label shape:', mnist.train.labels.shape)  # Training label shape: (55000, 10)

# Convenience aliases for the train / test splits.
train_X, train_Y = mnist.train.images, mnist.train.labels
test_X, test_Y = mnist.test.images, mnist.test.labels

# ---------------------------------------------------------------------------
# Network hyper-parameters
# ---------------------------------------------------------------------------
n_input = 784       # flattened 28x28 image
n_hidden_1 = 256    # width of the encoder / decoder hidden layers
n_hidden_2 = 2      # size of the latent code z
n_classes = 10      # length of the one-hot label vector

learning_rate = 0.001
training_epochs = 20  # number of passes over the training set
batch_size = 128      # samples per mini-batch
display_epoch = 3     # print the average cost every `display_epoch` epochs
show_num = 10         # number of images shown in each visualisation

# Graph inputs, created in this order:
#   x      - batch of flattened images
#   y      - matching one-hot labels (the "condition" of the CVAE)
#   zinput - latent codes fed straight to the decoder to synthesise new samples
x, y, zinput = (tf.placeholder(tf.float32, [None, width])
                for width in (n_input, n_classes, n_hidden_2))

# ---------------------------------------------------------------------------
# Learnable parameters
# ---------------------------------------------------------------------------
# Every weight matrix starts from a narrow truncated normal (stddev=0.001) and
# every bias starts at zero. Entries are listed in the order the variables are
# created, which also fixes their auto-generated TensorFlow names.
weights = dict(
    # encoder, image branch: 784 -> 256
    w1=tf.Variable(tf.truncated_normal([n_input, n_hidden_1], stddev=0.001)),
    # encoder, label branch: 10 -> 256
    w_lab1=tf.Variable(tf.truncated_normal([n_classes, n_hidden_1], stddev=0.001)),
    # concatenated encoder features (2 * 256) -> latent mean
    mean_w1=tf.Variable(tf.truncated_normal([n_hidden_1 * 2, n_hidden_2], stddev=0.001)),
    # concatenated encoder features (2 * 256) -> latent log-variance
    log_sigma_w1=tf.Variable(tf.truncated_normal([n_hidden_1 * 2, n_hidden_2], stddev=0.001)),
    # decoder: (z ++ label) -> 256
    w2=tf.Variable(tf.truncated_normal([n_hidden_2 + n_classes, n_hidden_1], stddev=0.001)),
    # decoder: 256 -> 784
    w3=tf.Variable(tf.truncated_normal([n_hidden_1, n_input], stddev=0.001)),
)

biases = dict(
    b1=tf.Variable(tf.zeros([n_hidden_1])),
    b_lab1=tf.Variable(tf.zeros([n_hidden_1])),
    mean_b1=tf.Variable(tf.zeros([n_hidden_2])),
    log_sigma_b1=tf.Variable(tf.zeros([n_hidden_2])),
    b2=tf.Variable(tf.zeros([n_hidden_1])),
    b3=tf.Variable(tf.zeros([n_input])),
)

# ---------------------------------------------------------------------------
# Network structure
# ---------------------------------------------------------------------------
# Encoder, image branch: 784-d pixels -> 256-d features.
h1 = tf.nn.relu(tf.add(tf.matmul(x, weights['w1']), biases['b1']))
# Encoder, label branch: one-hot label -> 256-d features.
h_lab1 = tf.nn.relu(tf.add(tf.matmul(y, weights['w_lab1']), biases['b_lab1']))
# Join both branches along the feature axis (-> 512-d).
hall1 = tf.concat([h1, h_lab1], 1)

# Two parallel heads give the parameters of q(z | x, y): the mean and
# log(sigma^2) -- the log-variance, not the variance itself.
z_mean = tf.add(tf.matmul(hall1, weights['mean_w1']), biases['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(hall1, weights['log_sigma_w1']), biases['log_sigma_b1'])

# Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, I), so sampling
# stays differentiable with respect to mu and log(sigma^2).
eps = tf.random_normal(tf.stack([tf.shape(h1)[0], n_hidden_2]), 0, 1, dtype=tf.float32)
# sigma = exp(0.5 * log(sigma^2)). This equals sqrt(exp(log(sigma^2))), but
# exp() on the full log-variance overflows float32 once it exceeds ~88, which
# turns sigma (and its gradient) into inf/NaN. Halving first avoids that.
z = tf.add(z_mean, tf.multiply(tf.exp(0.5 * z_log_sigma_sq), eps))


def _decode(latent_and_label):
    """Run the decoder: (z ++ one-hot label) -> 256-d hidden -> 784-d output.

    Every call reuses the same w2/b2/w3/b3 variables, so the training path
    and the generation path share one set of decoder weights.

    Returns:
        (hidden, output) tensors.
    """
    hidden = tf.nn.relu(tf.matmul(latent_and_label, weights['w2']) + biases['b2'])
    output = tf.matmul(hidden, weights['w3']) + biases['b3']
    return hidden, output


# Training path. This is where the condition enters: the label is appended to z.
zall = tf.concat([z, y], 1)  # None x (n_hidden_2 + n_classes) = None x 12
h2, reconstruction = _decode(zall)

# Generation path (not part of training): feed any latent code through
# `zinput`, together with a label `y`, to synthesise an image for that label.
zinputall = tf.concat([zinput, y], 1)
h2out, reconstructionout = _decode(zinputall)

# ---------------------------------------------------------------------------
# Loss and optimiser
# ---------------------------------------------------------------------------
# Reconstruction loss: half squared error between decoder output and input,
# summed over the pixel axis only, so it holds one value per sample (the same
# granularity as the KL term below). Summing over the whole batch into a scalar
# would weight this term batch_size times more than the averaged KL term. It
# would also make the reported cost grow with the number of samples fed.
# Cross-entropy is a possible alternative reconstruction loss.
reconstr_loss = 0.5 * tf.reduce_sum(tf.square(reconstruction - x), 1)
print(reconstr_loss.shape)  # (?,) -- one value per sample
# KL divergence between N(z_mean, sigma^2) and the prior N(0, I), per sample:
#   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
latent_loss = -0.5 * tf.reduce_sum(1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq), 1)
print(latent_loss.shape)  # (?,) -- one value per sample
# Per-sample (reconstruction + KL), averaged over the batch.
cost = tf.reduce_mean(reconstr_loss + latent_loss)

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# Mini-batches per epoch; the final partial batch is rounded up.
num_batch = int(np.ceil(mnist.train.num_examples / batch_size))

# ---------------------------------------------------------------------------
# Training, evaluation and visualisation
# ---------------------------------------------------------------------------
def _show_two_rows(top_images, bottom_images):
    """Plot `show_num` image pairs as a 2-row grid of 28x28 grayscale tiles.

    Column i shows top_images[i] in the first row and bottom_images[i] in the
    second row.
    """
    plt.figure(figsize=(1.0 * show_num, 1 * 2))
    for i in range(show_num):
        plt.subplot(2, show_num, i + 1)
        plt.imshow(np.reshape(top_images[i], (28, 28)), cmap='gray')
        plt.axis('off')

        plt.subplot(2, show_num, i + show_num + 1)
        plt.imshow(np.reshape(bottom_images[i], (28, 28)), cmap='gray')
        plt.axis('off')
    plt.show()


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print('开始训练')
    for epoch in range(training_epochs):
        total_cost = 0.0
        for _step in range(num_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, loss = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
            total_cost += loss

        # Report every `display_epoch` epochs and always after the last epoch.
        # With 20 epochs and display_epoch=3, the final epoch would otherwise
        # never be shown.
        if epoch % display_epoch == 0 or epoch == training_epochs - 1:
            print('Epoch {}/{}  average cost {:.9f}'.format(epoch + 1, training_epochs, total_cost / num_batch))

    print('训练完成')

    # Cost on the whole test set, with the real images and labels fed in.
    print('Result:', cost.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Reconstruct the first show_num test images (image + label -> z -> image).
    # Results get their own name so the `reconstruction` graph tensor is not
    # overwritten by a numpy array.
    test_images = mnist.test.images[:show_num]
    test_labels = mnist.test.labels[:show_num]
    recon_images = sess.run(reconstruction, feed_dict={x: test_images, y: test_labels})
    _show_two_rows(test_images, recon_images)

    # Conditional generation: no x is fed. Random latent codes drawn from
    # N(0, I) are paired with the first show_num test labels, so each generated
    # image is conditioned on the label of the original shown above it.
    z_sample = np.random.randn(show_num, n_hidden_2)
    generated_images = sess.run(reconstructionout, feed_dict={zinput: z_sample, y: test_labels})
    _show_two_rows(test_images, generated_images)